// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef INK_GEOMETRY_INTERNAL_STATIC_RTREE_H_
#define INK_GEOMETRY_INTERNAL_STATIC_RTREE_H_

#include <cmath>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/log/absl_check.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "ink/geometry/envelope.h"
#include "ink/geometry/internal/intersects_internal.h"
#include "ink/geometry/rect.h"
#include "ink/types/small_array.h"

namespace ink::geometry_internal {

// An R-Tree containing a fixed set of unchanging elements. R-Trees are a
// spatial index that allows for fast lookups by maintaining a balanced tree
// structure (http://www-db.deis.unibo.it/courses/SI-LS/papers/Gut84.pdf).
// Because this R-Tree cannot change, it is best suited for uses in which you
// compute the set of objects once, and then do not mutate it unless you are
// completely regenerating it (e.g. stroke meshes).
//
// Template parameter `T` is the type of the data elements that are stored in
// the tree. This doesn't actually need to be the data itself, though, e.g. it
// could instead be an index or identifier for data that is stored elsewhere.
//
// Template parameter `kBranchingFactor` determines the branching factor of the
// tree. It is exposed to facilitate testing of the structure, but you should
// generally just use the default value.
//
// Implementation note: the fact that the structure of the tree does not change
// allows us to store the tree in just two contiguous sections of memory (one
// each for the branch nodes and leaf nodes), and refer to nodes by their
// position in the storage vectors (a `uint32_t`, 4 bytes), instead of using
// pointers (8 bytes). This reduces allocations, improves cache locality, and
// reduces memory footprint.
template <typename T, uint32_t kBranchingFactor = 16>
class StaticRTree {
 public:
  static_assert(kBranchingFactor >= 2,
                "kBranchingFactor must be at least 2 to form a tree");

  // A branch node in the tree, which has children but no data elements. This is
  // exposed only for testing, and should not be needed for normal use of the
  // `StaticRTree`.
  struct BranchNode {
    // The minimum bounding rectangle of all descendants of this node.
    Rect bounds;
    // Indicates the type of this node's children: if true, the children are
    // leaf nodes (stored implicitly in `elements_`); if false, they are
    // `BranchNode`s. Initialized so that a default-constructed node never
    // carries an indeterminate value.
    bool is_leaf_parent = false;
    // The indices of the children of this node. They index into `elements_`
    // when `is_leaf_parent` is true, and into `branch_nodes_` otherwise.
    SmallArray<uint32_t, kBranchingFactor> child_indices;
  };

  // Constructs an empty `StaticRTree`. Note that, because the `StaticRTree`
  // cannot be changed after construction, you can't do much with an empty
  // `StaticRTree`.
  StaticRTree() = default;

  // Constructs a `StaticRTree` containing `elements`, using `bounds_func` to
  // compute the bounding rectangle of each element. `bounds_func` will be moved
  // into the `StaticRTree`; if it references data owned by another object, it
  // is the responsibility of the caller to ensure that that data remains valid
  // for the lifetime of the `StaticRTree`.
  //
  // This CHECK-fails if `elements` contains more than 2^32 (4294967296)
  // elements, or if `elements` is non-empty and `bounds_func` == nullptr.
  StaticRTree(absl::Span<const T> elements,
              std::function<Rect(const T&)> bounds_func);

  // Constructs a `StaticRTree` containing `n_elements` objects, which are
  // generated by repeatedly calling `generator`. `bounds_func` is used to
  // compute the bounding rectangle of each element, and will be moved into the
  // `StaticRTree`; if it references data owned by another object, it is the
  // responsibility of the caller to ensure that that data remains valid for the
  // lifetime of the `StaticRTree`.
  //
  // Template parameter `Generator` must have a zero-argument function call
  // operator that returns an object that is convertible to `T`, and it must be
  // valid to call `generator()` at least `n_elements` times.
  //
  // This CHECK-fails if `n_elements` > 0 and `bounds_func` == nullptr.
  template <typename Generator>
  StaticRTree(uint32_t n_elements, Generator generator,
              std::function<Rect(const T&)> bounds_func);

  StaticRTree(const StaticRTree&) = default;
  StaticRTree(StaticRTree&&) = default;
  StaticRTree& operator=(const StaticRTree&) = default;
  StaticRTree& operator=(StaticRTree&&) = default;

  // Visits the elements whose bounding box intersects `bounds` (as tested by
  // `IntersectsInternal`). The traversal continues until `visitor` returns
  // false, at which point no more elements will be visited. The visitation
  // order depends on the structure of the tree, which should be assumed to be
  // arbitrary, and may be non-deterministic.
  //
  // For example, if you have a `StaticRTree` containing `Segment`s, and wanted
  // to find the elements that intersect a `Triangle`, you could say:
  //   Triangle t;
  //   std::vector<Segment> output;
  //   rtree.VisitIntersectedElements(*Envelope(t).AsRect(),
  //       [&t, &output](const Segment& s) {
  //         if (Intersects(t, s)) output.push_back(s);
  //         return true;
  //       });
  void VisitIntersectedElements(
      const Rect& bounds, absl::FunctionRef<bool(const T&)> visitor) const;

  // Read-only views of the tree's internal storage.
  absl::Span<const BranchNode> BranchNodes() const { return branch_nodes_; }
  absl::Span<const T> Elements() const { return elements_; }

 private:
  // Initializes the structure of the tree, populating `branch_nodes_`. If
  // `elements_` is empty, this is a no-op, as there is nothing to put in the
  // tree. Otherwise, CHECK-fails if `bounds_func_` == nullptr.
  void InitializeTree();

  // This is a helper method for `VisitIntersectedElements`, which visits the
  // sub-tree whose root is the branch node at index `sub_tree_root_idx`. This
  // returns `true` if the traversal should continue, or `false` if it should
  // stop early.
  bool VisitIntersectedElementsInSubTree(
      uint32_t sub_tree_root_idx, const Rect& bounds,
      absl::FunctionRef<bool(const T&)> visitor) const;

  // The branch nodes are stored such that each node has a depth at least as
  // great as the previous one. This means that the root will always be the
  // first element, and that all branch nodes of a particular depth occupy a
  // contiguous span of `branch_nodes_`.
  std::vector<BranchNode> branch_nodes_;

  // The data elements in the R-Tree, which are also the leaf nodes (since they
  // contain no other information).
  std::vector<T> elements_;

  // Computes the bounding rectangle of an element. Used while building the
  // tree, and again for each candidate element during queries.
  std::function<Rect(const T&)> bounds_func_;
};

// -----------------------------------------------------------------------------
//                     Implementation details below

// Upper bound on the number of branch levels an R-Tree is expected to have.
// It is used as the inline capacity of the per-level bookkeeping vectors, so
// deeper trees still work but spill to the heap. A tree with N elements and
// branching factor B has about log(N) / log(B) branch levels. With the maximum
// element count (about 2^32) and the default branching factor (16), that is
// 32 / 4 = 8 levels.
inline constexpr int kMaxExpectedRTreeBranchDepth{8};

// Helper function for the `StaticRTree` ctor. Computes the number of branch
// nodes required at each depth to hold `n_leaf_nodes` leaves with the given
// `branching_factor`. The output is ordered from the root down (i.e. by
// increasing depth): the first element is the number of branch nodes at the
// root level (which will always be 1), and the last element is the number of
// branch nodes at the level directly above the leaf nodes.
absl::InlinedVector<uint32_t, kMaxExpectedRTreeBranchDepth>
ComputeNumberOfRTreeBranchNodesAtDepth(uint32_t n_leaf_nodes,
                                       uint32_t branching_factor);

// Helper function for the `StaticRTree` ctor. Computes the offsets in
// `StaticRTree::branch_nodes_` at which each contiguous span of same-depth
// nodes begins. `n_branch_nodes_at_depth` is expected to be the output from
// `ComputeNumberOfRTreeBranchNodesAtDepth`, i.e. the number of branch nodes at
// each depth, ordered from the root down. The output has one entry per depth,
// in the same order as the input.
absl::InlinedVector<uint32_t, kMaxExpectedRTreeBranchDepth>
ComputeRTreeBranchDepthOffsets(
    absl::Span<const uint32_t> n_branch_nodes_at_depth);

// Helper function for the `StaticRTree` ctor, bulk-loads one level of the
// R-Tree using the Sort-Tile-Recursive algorithm
// (www.cs.odu.edu/~mln/ltrs-pdfs/icase-1997-14.pdf).
//
// `index_of_first_parent_node` and `index_of_first_child_node` contain the
// indices of the first parent and child nodes in their respective containers;
// the parent nodes are always branch nodes, but the child nodes may be either
// leaf nodes (for the first pass) or branch nodes (for subsequent passes).
// `n_parent_nodes` and `n_child_nodes` indicate the quantity of each.
//
// Template parameter `ChildBoundsGetter` should be a functor of the form:
//   Rect Foo(uint32_t child_index)
// which returns the bounding rectangle for the child node at `child_index`.
//
// Template parameter `ParentAssigner` should be a functor of the form:
//   void Bar(uint32_t parent_index, const Rect& child_bounds,
//            absl::Span<const uint32_t> child_indices)
// which sets any fields necessary to make the parent node at `parent_index` the
// parent of the child nodes at `child_indices`. `child_bounds` will be the
// rectangle that contains all of the child nodes referred to by
// `child_indices`.
//
// Both functors are taken by value, so callers should avoid passing functors
// that own large amounts of state.
template <typename ChildBoundsGetter, typename ParentAssigner>
void BulkLoadOneLevelOfNodes(uint32_t index_of_first_parent_node,
                             uint32_t n_parent_nodes,
                             uint32_t index_of_first_child_node,
                             uint32_t n_child_nodes,
                             ChildBoundsGetter get_child_bounds,
                             ParentAssigner assign_children_to_parent,
                             int branching_factor) {
  // These should be guaranteed by the logic in the ctor.
  ABSL_DCHECK_GT(n_child_nodes, 0u);
  ABSL_DCHECK_EQ(n_parent_nodes, std::ceil(static_cast<double>(n_child_nodes) /
                                           branching_factor));

  // Sort-Tile-Recursive: with N = `n_child_nodes` and B = `branching_factor`,
  // we need P = ceil(N / B) parent nodes ("tiles"). The children are sorted by
  // x and cut into vertical slices of k * B children each, where
  // k = ceil(sqrt(P)). Each slice is then sorted by y and cut into tiles of B
  // children, so that the tiles approximate a k-by-k grid.

  auto compare_child_nodes_by_x = [&get_child_bounds](uint32_t lhs_idx,
                                                      uint32_t rhs_idx) {
    return get_child_bounds(lhs_idx).XMin() < get_child_bounds(rhs_idx).XMin();
  };
  auto compare_child_nodes_by_y = [&get_child_bounds](uint32_t lhs_idx,
                                                      uint32_t rhs_idx) {
    return get_child_bounds(lhs_idx).YMin() < get_child_bounds(rhs_idx).YMin();
  };

  // Instead of sorting the nodes themselves, which would invalidate any
  // references to them by index, we make a list of the indices, and sort those.
  std::vector<uint32_t> sortable_child_indices(n_child_nodes);
  absl::c_iota(sortable_child_indices, index_of_first_child_node);

  uint32_t n_tiles =
      std::ceil(static_cast<double>(n_child_nodes) / branching_factor);
  uint32_t base_slice_size =
      std::ceil(std::sqrt(static_cast<double>(n_tiles))) * branching_factor;
  uint32_t n_slices =
      std::ceil(static_cast<double>(n_child_nodes) / base_slice_size);

  // Sort the child nodes by x-coordinate, so that they can be divided
  // vertically to form slices.
  absl::c_sort(sortable_child_indices, compare_child_nodes_by_x);
  uint32_t parent_idx = index_of_first_parent_node;
  for (uint32_t slice_idx = 0; slice_idx < n_slices; ++slice_idx) {
    uint32_t first_child_in_slice = slice_idx * base_slice_size;
    // The last slice will have fewer children if `n_child_nodes` is not a
    // multiple of `base_slice_size`.
    uint32_t n_children_in_slice = slice_idx == n_slices - 1
                                       ? n_child_nodes - first_child_in_slice
                                       : base_slice_size;
    uint32_t n_tiles_in_slice =
        std::ceil(static_cast<double>(n_children_in_slice) / branching_factor);

    // Sort the child nodes in this slice by y-coordinate, so that they can be
    // divided horizontally to form tiles.
    auto slice_child_indices = absl::MakeSpan(
        &sortable_child_indices[first_child_in_slice], n_children_in_slice);
    absl::c_sort(slice_child_indices, compare_child_nodes_by_y);
    for (uint32_t tile_idx = 0; tile_idx < n_tiles_in_slice; ++tile_idx) {
      uint32_t first_child_in_tile = tile_idx * branching_factor;
      // The last tile of this slice will have fewer children if
      // `n_children_in_slice` is not a multiple of `branching_factor`. This
      // must compare against the per-slice tile count, not the total
      // `n_tiles`, otherwise the short tile would request `branching_factor`
      // children and run past the end of the slice.
      uint32_t n_children_in_tile =
          tile_idx == n_tiles_in_slice - 1
              ? n_children_in_slice - first_child_in_tile
              : branching_factor;

      absl::Span<const uint32_t> tile_child_indices =
          slice_child_indices.subspan(first_child_in_tile, n_children_in_tile);
      Envelope envelope;
      for (uint32_t child_idx : tile_child_indices) {
        envelope.Add(get_child_bounds(child_idx));
      }

      // This should be guaranteed by the logic in the ctor.
      ABSL_DCHECK_LT(parent_idx, index_of_first_parent_node + n_parent_nodes);
      assign_children_to_parent(parent_idx, *envelope.AsRect(),
                                tile_child_indices);
      ++parent_idx;
    }
  }
}

template <typename T, uint32_t kBranchingFactor>
StaticRTree<T, kBranchingFactor>::StaticRTree(
    absl::Span<const T> elements, std::function<Rect(const T&)> bounds_func)
    : elements_(elements.begin(), elements.end()),
      bounds_func_(std::move(bounds_func)) {
  // The element count is later narrowed to `uint32_t` (e.g. when passed to
  // `ComputeNumberOfRTreeBranchNodesAtDepth` and `BulkLoadOneLevelOfNodes`).
  // So the count itself, not just the largest index, must fit in a
  // `uint32_t`; exactly 2^32 elements would wrap around to 0.
  ABSL_CHECK_LT(elements.size(), uint64_t{1} << 32) << absl::Substitute(
      "StaticRTree supports a maximum of 2^32 - 1 (4294967295) elements; $0 "
      "were given",
      elements.size());
  InitializeTree();
}

template <typename T, uint32_t kBranchingFactor>
template <typename Generator>
StaticRTree<T, kBranchingFactor>::StaticRTree(
    uint32_t n_elements, Generator generator,
    std::function<Rect(const T&)> bounds_func)
    : bounds_func_(std::move(bounds_func)) {
  // Construct each element directly from the generator's output, rather than
  // default-constructing `n_elements` values and then assigning over them.
  // This avoids the redundant construction and means `T` need not be
  // default-constructible. `n_elements` is a `uint32_t`, so the count always
  // fits the tree's `uint32_t` indexing and no size check is needed.
  elements_.reserve(n_elements);
  for (uint32_t i = 0; i < n_elements; ++i) {
    elements_.emplace_back(generator());
  }
  InitializeTree();
}

template <typename T, uint32_t kBranchingFactor>
void StaticRTree<T, kBranchingFactor>::InitializeTree() {
  if (elements_.empty()) {
    // This is an empty R-Tree, there is nothing to initialize. Note that this
    // returns before `bounds_func_` is checked, so an empty tree may have a
    // null `bounds_func_`.
    return;
  }

  ABSL_CHECK(bounds_func_ != nullptr) << "bounds_func must be non-null";

  absl::InlinedVector<uint32_t, kMaxExpectedRTreeBranchDepth>
      n_branch_nodes_at_depth = ComputeNumberOfRTreeBranchNodesAtDepth(
          elements_.size(), kBranchingFactor);
  absl::InlinedVector<uint32_t, kMaxExpectedRTreeBranchDepth>
      branch_depth_offsets =
          ComputeRTreeBranchDepthOffsets(n_branch_nodes_at_depth);

  // The deepest branch level is stored last, so the total number of branch
  // nodes is that level's offset plus its node count.
  branch_nodes_.resize(branch_depth_offsets.back() +
                       n_branch_nodes_at_depth.back());

  // Compute each element's bounds once up front; the sorts in
  // `BulkLoadOneLevelOfNodes` query them repeatedly.
  std::vector<Rect> leaf_bounds(elements_.size());
  absl::c_transform(elements_, leaf_bounds.begin(), bounds_func_);

  // Capture `leaf_bounds` by reference: `BulkLoadOneLevelOfNodes` takes its
  // functors by value, so a by-value capture would copy the whole vector.
  // `leaf_bounds` outlives every use of this lambda.
  auto get_leaf_bounds = [&leaf_bounds](uint32_t idx) {
    return leaf_bounds[idx];
  };
  auto get_branch_bounds = [this](uint32_t idx) {
    return branch_nodes_[idx].bounds;
  };

  // Returns a functor that makes the branch node at `parent_idx` the parent of
  // `child_indices`. `children_are_leaves` records whether those indices
  // refer to `elements_` (true) or `branch_nodes_` (false).
  auto make_parent_assigner = [this](bool children_are_leaves) {
    return [this, children_are_leaves](
               uint32_t parent_idx, const Rect& child_bounds,
               absl::Span<const uint32_t> child_indices) {
      BranchNode& parent = branch_nodes_[parent_idx];
      parent.bounds = child_bounds;
      parent.is_leaf_parent = children_are_leaves;
      parent.child_indices =
          SmallArray<uint32_t, kBranchingFactor>(child_indices);
    };
  };
  auto assign_leaf_children_to_parent =
      make_parent_assigner(/* children_are_leaves = */ true);
  auto assign_branch_children_to_parent =
      make_parent_assigner(/* children_are_leaves = */ false);

  // Build the deepest branch level first, directly above the leaves...
  BulkLoadOneLevelOfNodes(
      branch_depth_offsets.back(), n_branch_nodes_at_depth.back(),
      /* index_of_first_child_node = */ 0, elements_.size(), get_leaf_bounds,
      assign_leaf_children_to_parent, kBranchingFactor);

  // ...then each shallower level up to the root (depth 0), whose children are
  // the branch nodes built in the previous step. Cast before subtracting so
  // that a single-level tree yields -1 instead of wrapping around as `size_t`.
  for (int depth = static_cast<int>(n_branch_nodes_at_depth.size()) - 2;
       depth >= 0; --depth) {
    BulkLoadOneLevelOfNodes(
        branch_depth_offsets[depth], n_branch_nodes_at_depth[depth],
        branch_depth_offsets[depth + 1], n_branch_nodes_at_depth[depth + 1],
        get_branch_bounds, assign_branch_children_to_parent, kBranchingFactor);
  }
}

template <typename T, uint32_t kBranchingFactor>
void StaticRTree<T, kBranchingFactor>::VisitIntersectedElements(
    const Rect& bounds, absl::FunctionRef<bool(const T&)> visitor) const {
  // An empty tree (default-constructed, or built from zero elements) has no
  // branch nodes, because `InitializeTree` returns early. There is no root to
  // inspect, and calling `front()` on the empty vector would be undefined
  // behavior.
  if (branch_nodes_.empty()) return;
  // The root is always stored first; if the query misses its bounds, no
  // element can match.
  if (!IntersectsInternal(branch_nodes_.front().bounds, bounds)) return;
  VisitIntersectedElementsInSubTree(0, bounds, visitor);
}

template <typename T, uint32_t kBranchingFactor>
bool StaticRTree<T, kBranchingFactor>::VisitIntersectedElementsInSubTree(
    uint32_t sub_tree_root_idx, const Rect& bounds,
    absl::FunctionRef<bool(const T&)> visitor) const {
  // Depth-first walk of the sub-tree rooted at `sub_tree_root_idx`. Returns
  // false as soon as `visitor` asks to stop, and true if the whole sub-tree was
  // walked.
  const BranchNode& root = branch_nodes_[sub_tree_root_idx];

  if (!root.is_leaf_parent) {
    // The children are branch nodes: descend only into those whose bounds
    // overlap the query.
    for (uint32_t child_branch_idx : root.child_indices.Values()) {
      if (!IntersectsInternal(branch_nodes_[child_branch_idx].bounds, bounds)) {
        continue;
      }
      if (!VisitIntersectedElementsInSubTree(child_branch_idx, bounds,
                                             visitor)) {
        // The visitor stopped the traversal somewhere below this node.
        return false;
      }
    }
    return true;
  }

  // The children are elements: test each one's bounds before visiting it.
  for (uint32_t element_idx : root.child_indices.Values()) {
    const T& element = elements_[element_idx];
    if (!IntersectsInternal(bounds_func_(element), bounds)) continue;
    if (!visitor(element)) return false;
  }
  return true;
}

}  // namespace ink::geometry_internal

#endif  // INK_GEOMETRY_INTERNAL_STATIC_RTREE_H_
